#!/usr/bin/env python3

import getopt
import os
import re
import select
import signal
import sys
import syslog
import time
from collections import defaultdict
import traceback

from swsscommon import swsscommon

from supervisor import childutils

# Optional file listing the processes whose heartbeat is watched. Each line
# should specify one process (as defined in the supervisord.conf file) in the
# following format:
#
# program:<process_name>
WATCH_PROCESSES_FILE = '/etc/supervisor/watchdog_processes'

# File listing the critical processes. Each line should specify either one
# critical process or one critical process group (as defined in the
# supervisord.conf file) in one of the following formats:
#
# program:<process_name>
# group:<group_name>
CRITICAL_PROCESSES_FILE = '/etc/supervisor/critical_processes'

# The FEATURE table in Config DB contains the 'auto_restart' field
FEATURE_TABLE_NAME = 'FEATURE'

# The HEARTBEAT table in Config DB contains the per-process heartbeat
# configuration (the 'alert_interval' field is read from it)
HEARTBEAT_TABLE_NAME = 'HEARTBEAT'

# Value of the 'timeout' parameter passed to select(...), in seconds
SELECT_TIMEOUT_SECS = 1.0

# Minimum number of seconds between two "not running" alerting messages for the
# same process; also used as the heartbeat threshold for processes that have no
# 'alert_interval' configured
ALERTING_INTERVAL_SECS = 60

# Source and tag used when publishing the "process exited" event
EVENTS_PUBLISHER_SOURCE = "sonic-events-host"
EVENTS_PUBLISHER_TAG = "process-exited-unexpectedly"

# Flag consulted by get_heartbeat_alert_interval() to decide whether the
# HEARTBEAT table still has to be loaded into the mapping below
heartbeat_alert_interval_initialized = False
# Maps a process name to its heartbeat alert interval in seconds
# (filled in by load_heartbeat_alert_interval())
heartbeat_alert_interval_mapping = defaultdict(dict)

def get_group_and_process_list(process_file):
    """
    @summary: Parse a processes file in which every non-blank line has the form
              'program:<process_name>' or 'group:<group_name>'.
              A malformed line is reported to syslog and terminates the script
              (exit status 5 when the line does not split into exactly two
              ':'-separated fields, 6 when the key is unknown or the name is empty).
    @return: Two lists, the group names followed by the process names, each in
             the order the entries appear in the file.
    """
    syntax_error_fmt = "Syntax of the line {} in processes file is incorrect. Exiting..."
    # One bucket per accepted key; anything else is a syntax error.
    names_by_kind = {"group": [], "program": []}

    with open(process_file, 'r') as handle:
        for raw_line in handle:
            # Lines made up solely of whitespace carry no entry.
            if re.match(r"^\s*$", raw_line):
                continue

            fields = raw_line.strip(' \n').split(':')
            if len(fields) != 2:
                syslog.syslog(syslog.LOG_ERR, syntax_error_fmt.format(raw_line))
                sys.exit(5)

            kind, name = (field.strip() for field in fields)
            if kind not in names_by_kind or not name:
                syslog.syslog(syslog.LOG_ERR, syntax_error_fmt.format(raw_line))
                sys.exit(6)

            names_by_kind[kind].append(name)

    return names_by_kind["group"], names_by_kind["program"]


def generate_alerting_message(process_name, status, dead_minutes, priority=syslog.LOG_ERR):
    """
    @summary: Write an alerting message about a process into syslog. The namespace
              reported in the message is built from the NAMESPACE_PREFIX and
              NAMESPACE_ID environment variables; if either one is unset or empty
              the process is reported as residing in 'host'.
    """
    prefix = os.environ.get("NAMESPACE_PREFIX")
    identifier = os.environ.get("NAMESPACE_ID")
    namespace = prefix + identifier if prefix and identifier else "host"

    message = "Process '{}' is {} in namespace '{}' ({} minutes).".format(
        process_name, status, namespace, dead_minutes)
    syslog.syslog(priority, message)


def get_autorestart_state(container_name, use_unix_socket_path):
    """
    @summary: Look up the 'auto_restart' field of the given container in the
              FEATURE table of Config_DB. The script exits with status 2 if the
              table is empty or missing, 3 if the container has no entry in it,
              and 4 if the entry has no (or an empty) 'auto_restart' value.
    @return: The value of the 'auto_restart' field.
    """
    def log_and_exit(message, exit_code):
        # Every failure path logs at LOG_ERR and then terminates the script.
        syslog.syslog(syslog.LOG_ERR, message)
        sys.exit(exit_code)

    connector = swsscommon.ConfigDBConnector(use_unix_socket_path=use_unix_socket_path)
    connector.connect()

    features = connector.get_table(FEATURE_TABLE_NAME)
    if not features:
        log_and_exit("Unable to retrieve features table from Config DB. Exiting...", 2)

    if container_name not in features:
        log_and_exit("Unable to retrieve feature '{}'. Exiting...".format(container_name), 3)

    auto_restart_state = features[container_name].get('auto_restart')
    if not auto_restart_state:
        log_and_exit(
            "Unable to determine auto-restart feature status for '{}'. Exiting...".format(container_name), 4)

    return auto_restart_state

def load_heartbeat_alert_interval(use_unix_socket_path):
    """
    @summary: Read the HEARTBEAT table from Config_DB and store, for every entry
              that has a non-empty 'alert_interval' field, that value divided by
              1000 in heartbeat_alert_interval_mapping (keyed by the table key,
              i.e. the process name). Afterwards the module-level flag
              heartbeat_alert_interval_initialized is set to True so that the
              table is not read again.
    @param use_unix_socket_path: Passed through to ConfigDBConnector.
    """
    # Without this declaration the assignment at the end of the function would
    # only bind a function-local name: the module-level flag would stay False
    # and Config_DB would be reconnected and re-read on every interval lookup.
    global heartbeat_alert_interval_initialized

    config_db = swsscommon.ConfigDBConnector(use_unix_socket_path=use_unix_socket_path)
    config_db.connect()
    heartbeat_table = config_db.get_table(HEARTBEAT_TABLE_NAME)
    if heartbeat_table:
        for process, entry in heartbeat_table.items():
            alert_interval = entry.get('alert_interval')
            if alert_interval:
                # The configured value is divided by 1000 and the result is later
                # compared against elapsed seconds (presumably the config is in
                # milliseconds -- confirm against the HEARTBEAT schema).
                heartbeat_alert_interval_mapping[process] = int(alert_interval) / 1000

    heartbeat_alert_interval_initialized = True

def get_heartbeat_alert_interval(process, use_unix_socket_path):
    """
    @summary: Look up the heartbeat alert interval of a process. The HEARTBEAT
              table is loaded first whenever the module-level flag
              heartbeat_alert_interval_initialized is not set.
    @return: The interval stored for the process in heartbeat_alert_interval_mapping,
             or ALERTING_INTERVAL_SECS when the process has no entry.
    """
    if not heartbeat_alert_interval_initialized:
        load_heartbeat_alert_interval(use_unix_socket_path)

    # dict.get() never invokes the defaultdict factory, so a miss adds no entry.
    return heartbeat_alert_interval_mapping.get(process, ALERTING_INTERVAL_SECS)

def publish_events(events_handle, process_name, container_name):
    """
    @summary: Publish an event tagged EVENTS_PUBLISHER_TAG through the given
              publisher handle, carrying the process name ('process_name') and
              the container name ('ctr_name') as parameters.
    """
    event_params = swsscommon.FieldValueMap()
    for field, value in (("process_name", process_name), ("ctr_name", container_name)):
        event_params[field] = value
    swsscommon.event_publish(events_handle, EVENTS_PUBLISHER_TAG, event_params)

def main(argv):
    """
    @summary: Entry point of the listener. Parses the command line
              (-c/--container-name <name>, -s/--use-unix-socket-path), loads the
              critical processes file and the optional watch processes file, and
              then loops forever handling events read from stdin:

              - PROCESS_STATE_EXITED: when a critical process (or a process of a
                critical group) exited unexpectedly, either SIGTERM is sent to the
                parent process (auto-restart not "disabled") or the process is
                tracked so that "not running" alerts are written periodically.
              - PROCESS_STATE_RUNNING: the process is no longer tracked for alerts.
              - PROCESS_COMMUNICATION_STDOUT: the heartbeat time of a watched
                process is refreshed.

              After each pass (event or select timeout) the pending "not running"
              and "stuck" alerts are evaluated.
              The script exits with status 1 if no container name was given.
    @param argv: Command-line arguments without the program name.
    """
    container_name = None
    use_unix_socket_path = False
    opts, args = getopt.getopt(argv, "c:s", ["container-name=", "use-unix-socket-path"])
    for opt, arg in opts:
        if opt in ("-c", "--container-name"):
            container_name = arg
        if opt in ("-s", "--use-unix-socket-path"):
            use_unix_socket_path = True

    if not container_name:
        syslog.syslog(syslog.LOG_ERR, "Container name not specified. Exiting...")
        sys.exit(1)

    critical_group_list, critical_process_list = get_group_and_process_list(CRITICAL_PROCESSES_FILE)

    # WATCH_PROCESSES_FILE is optional
    watch_process_list = []
    if os.path.exists(WATCH_PROCESSES_FILE):
        _, watch_process_list = get_group_and_process_list(WATCH_PROCESSES_FILE)

    # process name -> {"last_alerted": epoch seconds, "dead_minutes": accumulated minutes}
    process_under_alerting = defaultdict(dict)
    # process name -> {"last_heart_beat": epoch seconds}
    process_heart_beat_info = defaultdict(dict)
    # Transition from ACKNOWLEDGED to READY
    childutils.listener.ready()
    events_handle = swsscommon.events_init_publisher(EVENTS_PUBLISHER_SOURCE)
    while True:
        # Wait for an event on stdin, but at most SELECT_TIMEOUT_SECS so that the
        # alert checks at the bottom of the loop still run when no event arrives.
        file_descriptor_list = select.select([sys.stdin], [], [], SELECT_TIMEOUT_SECS)[0]
        if len(file_descriptor_list) > 0:
            # The first line carries the event headers; 'len' gives the payload size.
            line = file_descriptor_list[0].readline()
            headers = childutils.get_headers(line)
            # Check if 'len' exists before using it
            if 'len' not in headers:
                syslog.syslog(syslog.LOG_WARNING, "Missing 'len' in headers: {}, line: {}".format(headers, line))
                # NOTE(review): this 'continue' skips listener.ok()/ready() as well as
                # the alert checks below for this pass -- confirm that is intended.
                continue
            payload = sys.stdin.read(int(headers['len']))

            # Handle the PROCESS_STATE_EXITED event
            if headers['eventname'] == 'PROCESS_STATE_EXITED':
                payload_headers, payload_data = childutils.eventdata(payload + '\n')

                expected = int(payload_headers['expected'])
                process_name = payload_headers['processname']
                group_name = payload_headers['groupname']

                # Only unexpected exits (expected == 0) of critical processes/groups matter.
                if (process_name in critical_process_list or group_name in critical_group_list) and expected == 0:
                    is_auto_restart = get_autorestart_state(container_name, use_unix_socket_path)
                    if is_auto_restart != "disabled":
                        MSG_FORMAT_STR = "Process '{}' exited unexpectedly. Terminating supervisor '{}'"
                        msg = MSG_FORMAT_STR.format(payload_headers['processname'], container_name)
                        syslog.syslog(syslog.LOG_INFO, msg)
                        publish_events(events_handle, payload_headers['processname'], container_name)
                        swsscommon.events_deinit_publisher(events_handle)
                        # Ask the parent process (the supervisor, per the message above) to terminate.
                        os.kill(os.getppid(), signal.SIGTERM)
                    else:
                        # Auto-restart is disabled: start tracking the process so that
                        # periodic "not running" alerts are written instead.
                        process_under_alerting[process_name]["last_alerted"] = time.time()
                        process_under_alerting[process_name]["dead_minutes"] = 0

            # Handle the PROCESS_STATE_RUNNING event
            elif headers['eventname'] == 'PROCESS_STATE_RUNNING':
                payload_headers, payload_data = childutils.eventdata(payload + '\n')
                process_name = payload_headers['processname']

                # The process is running again, so stop alerting about it.
                if process_name in process_under_alerting:
                    process_under_alerting.pop(process_name)

            # Handle the PROCESS_COMMUNICATION_STDOUT event
            elif headers['eventname'] == 'PROCESS_COMMUNICATION_STDOUT':
                payload_headers, payload_data = childutils.eventdata(payload + '\n')
                process_name = payload_headers['processname']

                # update process heart beat time
                if (process_name in watch_process_list):
                    process_heart_beat_info[process_name]["last_heart_beat"] = time.time()

            # Transition from BUSY to ACKNOWLEDGED
            childutils.listener.ok()

            # Transition from ACKNOWLEDGED to READY
            childutils.listener.ready()

        # Check whether we need to write "not running" alerting messages into syslog:
        # one message per tracked process at most every ALERTING_INTERVAL_SECS seconds.
        for process_name in process_under_alerting.keys():
            epoch_time = time.time()
            elapsed_secs = epoch_time - process_under_alerting[process_name]["last_alerted"]
            if elapsed_secs >= ALERTING_INTERVAL_SECS:
                elapsed_mins = elapsed_secs // 60
                process_under_alerting[process_name]["last_alerted"] = epoch_time
                process_under_alerting[process_name]["dead_minutes"] += elapsed_mins
                generate_alerting_message(process_name, "not running", process_under_alerting[process_name]["dead_minutes"])

        # Check whether we need to write "stuck" alerting messages into syslog for
        # watched processes whose last heartbeat is older than their threshold.
        # NOTE(review): no "last alerted" time is kept here, so the warning is
        # written on every pass of the loop while the threshold is exceeded.
        for process in process_heart_beat_info.keys():
            epoch_time = time.time()
            elapsed_secs = epoch_time - process_heart_beat_info[process]["last_heart_beat"]
            threshold = get_heartbeat_alert_interval(process, use_unix_socket_path)
            # A threshold that is not positive disables the check for this process.
            if threshold > 0 and elapsed_secs >= threshold:
                elapsed_mins = elapsed_secs // 60
                generate_alerting_message(process, "stuck", elapsed_mins, syslog.LOG_WARNING)

if __name__ == "__main__":
    # Top-level boundary: report any unhandled exception (with its traceback)
    # to syslog and exit with a non-zero status.
    try:
        main(sys.argv[1:])
    except Exception as err:
        failure_report = "Exception: {}, trace: {}".format(err, traceback.format_exc())
        syslog.syslog(syslog.LOG_ERR, failure_report)
        sys.exit(1)
